{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch: Training ResNet50 using the AISShardReader and WebDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We train the ResNet50 model using dummy ImageNet data sharded in the WebDataset format. Note that you can download the actual ImageNet dataset and use that instead if you would like."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1) Import necessary packages, define constants, and create AIS Client."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    from aistore.sdk import Client\n",
    "    from aistore.pytorch.shard_reader import AISShardReader\n",
    "except:\n",
    "\n",
    "    # Use local version of aistore if pip version is too old or aistore not installed\n",
    "    import sys\n",
    "\n",
    "    sys.path.append(\"../../\")\n",
    "\n",
    "    from aistore.sdk import Client\n",
    "    from aistore.pytorch.shard_reader import AISShardReader\n",
    "\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.models import resnet50\n",
    "from torch.utils.data import DataLoader\n",
    "from torch import nn, optim, no_grad, max, stack, tensor\n",
    "from random import shuffle\n",
    "from PIL import Image\n",
    "from io import BytesIO\n",
    "\n",
    "import requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "AIS_ENDPOINT = \"http://localhost:8080\"\n",
    "AIS_PROVIDER = \"ais\"\n",
    "BCK_NAME = \"fake-imagenet\"\n",
    "DATASET_URL = \"https://storage.googleapis.com/webdataset/fake-imagenet/imagenet-train-{000000..001281}.tar\"\n",
    "\n",
    "client = Client(endpoint=AIS_ENDPOINT)\n",
    "bucket = client.bucket(BCK_NAME, AIS_PROVIDER).create(exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2) Populate the bucket with WebDataset formatted shards using the AIS CLI.\n",
    "\n",
    "or do `ais start download \"https://storage.googleapis.com/webdataset/fake-imagenet/imagenet-train-{000000..001281}.tar\" ais://fake-imagenet`.\n",
    "\n",
    "The dataset comes from the official WebDataset example notebooks: https://github.com/webdataset/webdataset/blob/main/examples/train-resnet50-wds.ipynb."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading dataset...\n",
      "Done putting files into buckets.\n",
      "Cleaned up downloaded dataset.\n"
     ]
    }
   ],
   "source": [
    "bucket = client.bucket(BCK_NAME, AIS_PROVIDER)\n",
    "try:\n",
    "    bucket.create(exist_ok=False)\n",
    "\n",
    "    print(\"Downloading dataset...\")\n",
    "\n",
    "    headers = {\"User-Agent\": \"Mozilla/5.0\"}\n",
    "\n",
    "    for i in range(1282):\n",
    "        tar_url = DATASET_URL.replace(\"{000000..001281}\", f\"{i:06}\")\n",
    "        name = tar_url.split(\"/\")[-1]\n",
    "\n",
    "        response = requests.get(tar_url, headers=headers, stream=True)\n",
    "        response.raise_for_status()\n",
    "\n",
    "        data = BytesIO(response.content)\n",
    "        bucket.object(name).get_writer().put_content(data.read())\n",
    "\n",
    "    print(\"Done putting files into buckets.\")\n",
    "\n",
    "    print(\"Cleaned up downloaded dataset.\")\n",
    "except Exception as e:\n",
    "    print(\"Bucket already has dataset! Nothing will be done.\")\n",
    "    bucket.create(exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3) Generate random split of indices and pass to ShardReader."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAIN_SPLIT = 0.80\n",
    "NUM_SHARDS = 100  # 1281 is total number, we will take subset to save time\n",
    "\n",
    "shard_indices = list(range(NUM_SHARDS))\n",
    "shuffle(shard_indices)\n",
    "\n",
    "train_boundary = int(len(shard_indices) * 0.8)\n",
    "\n",
    "train_indices = shard_indices[:train_boundary]\n",
    "validation_indices = shard_indices[train_boundary:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_shards = AISShardReader(\n",
    "    bucket_list=bucket,\n",
    "    prefix_map={\n",
    "        bucket: [f\"imagenet-train-{index:06}.tar\" for index in train_indices]\n",
    "    },  # :06 because each number has two digits prepended\n",
    ")\n",
    "\n",
    "validation_shards = AISShardReader(\n",
    "    bucket_list=bucket,\n",
    "    prefix_map={\n",
    "        bucket: [f\"imagenet-train-{index:06}.tar\" for index in validation_indices]\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 4) Create DataLoader and pass in parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 100\n",
    "NUM_WORKERS = 16\n",
    "\n",
    "train_loader = DataLoader(train_shards, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)\n",
    "\n",
    "validation_loader = DataLoader(\n",
    "    validation_shards, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 5) Define model, hyperparameters, optimizer, transforms, and loss function for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LEARNING_RATE = 0.1\n",
    "WEIGHT_DECAY = 5e-4\n",
    "\n",
    "resnet_model = resnet50()\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(\n",
    "    resnet_model.parameters(),\n",
    "    lr=0.01,\n",
    "    momentum=LEARNING_RATE,\n",
    "    weight_decay=WEIGHT_DECAY,\n",
    ")\n",
    "\n",
    "transform_train = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 6) Train the model on a number of epochs and validate accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_EPOCHS = 1\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    print(f\"EPOCH {epoch + 1}\\n-----------\")\n",
    "\n",
    "    loss = 0\n",
    "    i = 0\n",
    "    for i, (_, contents) in enumerate(train_loader):\n",
    "\n",
    "        images = stack(\n",
    "            [\n",
    "                transform_train(Image.open(BytesIO(image_bytes)))\n",
    "                for image_bytes in contents[\"jpg\"]\n",
    "            ]\n",
    "        )\n",
    "        labels = tensor(\n",
    "            [int(label_bytes.decode(\"utf-8\")) for label_bytes in contents[\"cls\"]]\n",
    "        )\n",
    "\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        outputs = resnet_model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # print statistics\n",
    "        loss += loss.item()\n",
    "\n",
    "        print(f\"Batch: {i + 1}, Samples Processed: {(i + 1) * BATCH_SIZE}\")\n",
    "\n",
    "    print(f\"Samples Processed: {(i+1) * BATCH_SIZE}, Loss: {loss / 100}\")\n",
    "\n",
    "    # Validation\n",
    "    resnet_model.eval()\n",
    "    with no_grad():\n",
    "        correct = 0\n",
    "        total = 0\n",
    "        for _, contents in validation_loader:\n",
    "\n",
    "            images = stack(\n",
    "                [\n",
    "                    transform_train(Image.open(BytesIO(image_bytes)))\n",
    "                    for image_bytes in contents[\"jpg\"]\n",
    "                ]\n",
    "            )\n",
    "            labels = tensor(\n",
    "                [int(label_bytes.decode(\"utf-8\")) for label_bytes in contents[\"cls\"]]\n",
    "            )\n",
    "\n",
    "            if len(labels) != BATCH_SIZE:\n",
    "                print(len(labels))\n",
    "            outputs = resnet_model(images)\n",
    "            _, predicted = max(outputs.data, 1)\n",
    "\n",
    "            correct += (predicted == labels).sum().item()\n",
    "            total += len(labels)\n",
    "\n",
    "    print(f\"Accuracy: {100 * correct / total}%\\n\")\n",
    "\n",
    "print(\"-----------\\nFinished Training\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "notebook",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
